
import torch
import torch.nn as nn
import torch.nn.functional as F

class GS_Parametrization(nn.Module):
    """Map point patches to one 13-value parameter row per patch.

    Each output row is laid out as
    ``[position (3) | rotation entries (6) | color (3) | opacity (1)]``.
    Rows belonging to patches that are masked out are left at zero.

    Args:
        num_points_in_patch: Reference patch size used to normalise the
            point count when computing opacity.
        init_scale: Factor applied to the SVD basis returned by
            ``_init_covariance``.
    """

    def __init__(self, num_points_in_patch, init_scale=0.05):
        super().__init__()
        self.num_points_in_patch = num_points_in_patch
        self.init_scale = init_scale

        # Small MLP: 3 input features -> 32 hidden -> 3 outputs. In
        # ``forward`` it is fed the per-patch mean of channels 3:6.
        self.scale_mlp = nn.Sequential(
            nn.Linear(3, 32),
            nn.ReLU(),
            nn.Linear(32, 3),
        )

    def _init_covariance(self, positions: torch.Tensor) -> torch.Tensor:
        """Return a scaled orthonormal basis derived from patch covariances.

        Args:
            positions: ``[B, K, 3]`` point coordinates (B patches of K points).

        Returns:
            ``[B, 3, 3]`` tensor: the left singular vectors of the
            (perturbed) per-patch covariance, multiplied by ``init_scale``.
        """
        centroid = positions.mean(dim=1, keepdim=True)  # [B, 1, 3]
        centered = positions - centroid  # [B, K, 3]
        # Biased (divide-by-K) covariance of each patch: [B, 3, 3].
        cov = torch.bmm(centered.transpose(1, 2), centered) / positions.shape[1]

        # NOTE(review): ``centered`` has zero mean by construction, so this
        # vector is zero up to floating-point rounding and the normalised
        # direction is not a meaningful surface normal. Behaviour is kept
        # as-is; confirm what direction was intended here.
        normal = F.normalize(centered.mean(dim=1), p=2, dim=-1)  # [B, 3]
        cov = cov + torch.matmul(normal.unsqueeze(-1), normal.unsqueeze(-2)) * 0.1

        # torch.svd is deprecated; torch.linalg.svd yields the same U for
        # square input. Only the left singular vectors are used.
        basis, _, _ = torch.linalg.svd(cov)
        return basis * self.init_scale

    def forward(self, patches: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        """Compute the parameter rows for every patch.

        Args:
            patches: ``[M, K, C]`` tensor with ``C >= 6``. Channels 0:3 are
                read as point coordinates and channels 3:6 as the per-point
                features that are averaged into the color output; any
                further channels are ignored.
            valid_mask: ``[M]`` boolean mask selecting the patches to process.

        Returns:
            ``[M, 13]`` tensor with the same dtype and device as the
            computed parameters; rows where ``valid_mask`` is False are zero.
        """
        valid_patches = patches[valid_mask]  # [M_valid, K, C]
        num_valid, points_per_patch = valid_patches.shape[:2]

        xyz = valid_patches[..., :3]  # [M_valid, K, 3]
        # Slice exactly 3:6 so the MLP (which expects 3 inputs) and the color
        # output agree, and inputs carrying extra channels do not break the
        # linear layer.
        feature_mean = valid_patches[..., 3:6].mean(dim=1)  # [M_valid, 3]

        positions = xyz.mean(dim=1)  # [M_valid, 3]
        geometric_basis = self._init_covariance(xyz)  # [M_valid, 3, 3]
        learned_offset = self.scale_mlp(feature_mean)  # [M_valid, 3]

        # The learned 3-vector is broadcast across the columns of the basis.
        rotation = geometric_basis + 0.1 * learned_offset.unsqueeze(-1)  # [M_valid, 3, 3]

        colors = feature_mean  # [M_valid, 3]

        # Opacity depends only on how full the patch is relative to the
        # reference size; build it in the input dtype so concatenation and
        # the final scatter do not mix dtypes.
        density = torch.full(
            (num_valid, 1),
            points_per_patch / self.num_points_in_patch,
            dtype=valid_patches.dtype,
            device=valid_patches.device,
        )
        opacity = torch.sigmoid(density * 0.5)  # [M_valid, 1]

        # Keep the first six row-major entries (the first two rows) of the
        # 3x3 matrix.
        rotation_flat = rotation.reshape(num_valid, 9)[:, :6]  # [M_valid, 6]

        params = torch.cat([positions, rotation_flat, colors, opacity], dim=1)  # [M_valid, 13]

        # Allocate the output with the dtype/device of ``params``; a default
        # float32 buffer makes the masked assignment fail for non-float32
        # modules (e.g. after ``.double()``).
        full_params = params.new_zeros(valid_mask.shape[0], 13)
        full_params[valid_mask] = params

        return full_params  # [M, 13]



